#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import logging
import os
import sys
import unittest
from unittest import mock

import paddle
from paddle.jit.dy2static import logging_utils
from paddle.utils import gast


class TestLoggingUtils(unittest.TestCase):
    def setUp(self):
        self.verbosity_level = 1
        self.code_level = 3
        self.translator_logger = logging_utils._TRANSLATOR_LOGGER

    def test_verbosity(self):
        paddle.jit.set_verbosity(None)
        os.environ[logging_utils.VERBOSITY_ENV_NAME] = '3'
        self.assertEqual(logging_utils.get_verbosity(), 3)

        paddle.jit.set_verbosity(self.verbosity_level)
        self.assertEqual(self.verbosity_level, logging_utils.get_verbosity())

        # String is not supported
        with self.assertRaises(TypeError):
            paddle.jit.set_verbosity("3")

        with self.assertRaises(TypeError):
            paddle.jit.set_verbosity(3.3)

    def test_also_to_stdout(self):
        logging_utils._TRANSLATOR_LOGGER.need_to_echo_log_to_stdout = None
        self.assertEqual(
            logging_utils._TRANSLATOR_LOGGER.need_to_echo_log_to_stdout, False
        )

        paddle.jit.set_verbosity(also_to_stdout=False)
        self.assertEqual(
            logging_utils._TRANSLATOR_LOGGER.need_to_echo_log_to_stdout, False
        )

        logging_utils._TRANSLATOR_LOGGER.need_to_echo_node_to_stdout = None
        self.assertEqual(
            logging_utils._TRANSLATOR_LOGGER.need_to_echo_code_to_stdout, False
        )

        paddle.jit.set_code_level(also_to_stdout=True)
        self.assertEqual(
            logging_utils._TRANSLATOR_LOGGER.need_to_echo_code_to_stdout, True
        )

        with self.assertRaises(AssertionError):
            paddle.jit.set_verbosity(also_to_stdout=1)

        with self.assertRaises(AssertionError):
            paddle.jit.set_code_level(also_to_stdout=1)

    def test_set_code_level(self):
        paddle.jit.set_code_level(None)
        os.environ[logging_utils.CODE_LEVEL_ENV_NAME] = '2'
        self.assertEqual(logging_utils.get_code_level(), 2)

        paddle.jit.set_code_level(self.code_level)
        self.assertEqual(logging_utils.get_code_level(), self.code_level)

        paddle.jit.set_code_level(9)
        self.assertEqual(logging_utils.get_code_level(), 9)

        with self.assertRaises(TypeError):
            paddle.jit.set_code_level(3.3)

    def test_log_api(self):
        # test api for CI Converage
        logging_utils.set_verbosity(1, True)

        logging_utils.warn("warn")
        logging_utils.error("error")

        logging_utils.log(1, "log level 1")
        logging_utils.log(2, "log level 2")

        source_code = "x = 3"
        ast_code = gast.parse(source_code)
        logging_utils.set_code_level(1, True)
        logging_utils.log_transformed_code(1, ast_code, "TestTransformer")
        logging_utils.set_code_level(logging_utils.LOG_AllTransformer, True)
        logging_utils.log_transformed_code(
            logging_utils.LOG_AllTransformer, ast_code, "TestTransformer"
        )

    def test_log_message(self):
        stream = io.StringIO()
        log = self.translator_logger.logger
        stdout_handler = logging.StreamHandler(stream)
        log.addHandler(stdout_handler)

        warn_msg = "test_warn"
        error_msg = "test_error"
        log_msg_1 = "test_log_1"
        log_msg_2 = "test_log_2"

        with mock.patch.object(sys, 'stdout', stream):
            logging_utils.set_verbosity(1, False)
            logging_utils.warn(warn_msg)
            logging_utils.error(error_msg)
            logging_utils.log(1, log_msg_1)
            logging_utils.log(2, log_msg_2)

        result_msg = '\n'.join(
            [warn_msg, error_msg, "(Level 1) " + log_msg_1, ""]
        )
        self.assertEqual(result_msg, stream.getvalue())

    def test_log_transformed_code(self):
        source_code = "x = 3"
        ast_code = gast.parse(source_code)

        stream = io.StringIO()
        log = self.translator_logger.logger
        stdout_handler = logging.StreamHandler(stream)
        log.addHandler(stdout_handler)

        with mock.patch.object(sys, 'stdout', stream):
            paddle.jit.set_code_level(1)
            logging_utils.log_transformed_code(
                1, ast_code, "BasicApiTransformer"
            )

            paddle.jit.set_code_level()
            logging_utils.log_transformed_code(
                logging_utils.LOG_AllTransformer, ast_code, "All Transformers"
            )

        self.assertIn(source_code, stream.getvalue())


if __name__ == '__main__':
    unittest.main()
